import torch
import torch.nn as nn
import copy
import dgl
import numpy as np
import json
from typing import List, Union
import torch
import torch.nn as nn
import scipy.sparse as sp
from scipy.sparse import coo_matrix
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset, DataLoader
import math
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities
import community as community_louvain




class VAEPredictor(nn.Module):
    """Edge predictor that scores user-item pairs from two embedding views.

    Each user (item) representation is the row-wise concatenation of two
    embedding tables, passed through a two-layer MLP (``user_encode`` /
    ``item_encode``). A pair is scored by the dot product of the encoded
    user and item vectors.
    """

    def __init__(self, config):
        """Build the user and item encoders.

        Args:
            config: mapping providing ``'embedding_size'``; it is used both as
                the width of each input table and as the hidden/output width.
        """
        super().__init__()
        self.hidden_size = config['embedding_size']
        self.embd_size = config['embedding_size']
        # Inputs are two concatenated tables, hence 2 * embd_size features.
        self.user_encode = nn.Sequential(
            nn.Linear(2 * self.embd_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size))
        self.item_encode = nn.Sequential(
            nn.Linear(2 * self.embd_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size))
        # NOTE(review): not used by forward(); kept for attribute compatibility.
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _encode(encoder, first, second, index=None):
        """Optionally select rows ``index`` of both tables, concatenate them
        along the last dim and run ``encoder`` on the result."""
        if index is not None:
            first = first[index]
            second = second[index]
        return encoder(torch.cat((first, second), dim=-1))

    def forward(self,
                mode,
                user_all_embeddings,
                item_all_embeddings,
                user_latent,
                item_latent,
                label_u_interactions,
                label_i_interactions
                ):
        """Score masked edges (``'train'``) or encode candidates (``'generate'``).

        The ``*_all_embeddings`` table is concatenated before the matching
        ``*_latent`` table; only that order matters, because both are simply
        concatenated feature-wise.

        Args:
            mode: ``'train'`` or ``'generate'``.
            user_all_embeddings, user_latent: user tables whose last dims sum
                to ``2 * embedding_size`` (presumably ``embedding_size`` each).
            item_all_embeddings, item_latent: the item counterparts.
            label_u_interactions, label_i_interactions: row indices of the
                masked-out edges' users and items (presumably equal-length
                index tensors, one entry per edge -- confirm with caller).
                ``label_i_interactions`` is ignored in ``'generate'`` mode.

        Returns:
            ``'train'``: one sigmoid score per selected (user, item) pair.
            ``'generate'``: ``(encoded selected users, encoded all items)``.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'generate'``.
        """
        # Masked edges: the similarity of each masked (user, item) pair
        # should be maximized.
        if mode == 'train':
            u_embeddings = self._encode(self.user_encode, user_all_embeddings,
                                        user_latent, label_u_interactions)
            i_embeddings = self._encode(self.item_encode, item_all_embeddings,
                                        item_latent, label_i_interactions)
            # NOTE(review): this is a per-edge probability, not a loss. If a
            # caller minimizes its sum, masked-edge similarity is pushed DOWN,
            # contrary to the comment above; the unused ``self.loss``
            # (BCELoss) suggests BCE against target 1 was intended. Confirm.
            return self.sigmoid(torch.mul(u_embeddings, i_embeddings).sum(dim=1))

        if mode == 'generate':
            u_embeddings = self._encode(self.user_encode, user_all_embeddings,
                                        user_latent, label_u_interactions)
            # Every item is a candidate, so no row selection on the item side.
            i_embeddings = self._encode(self.item_encode, item_all_embeddings,
                                        item_latent)
            return u_embeddings, i_embeddings

        raise ValueError(
            f"Unknown mode {mode!r}; expected 'train' or 'generate'.")

class VAEReSample(nn.Module):
    """Reparameterisation (resampling) layer of the VAE.

    For users and items separately, two MLP heads map an embedding to a
    Gaussian mean and to a second quantity that is used as a log-variance
    (despite the ``logstd`` name: the KL term uses ``exp(v)`` as the
    variance and sampling scales noise by ``exp(v / 2)``).
    """

    def __init__(self, config) -> None:
        """Create the four mean / log-variance heads.

        Args:
            config: mapping providing ``'embedding_size'`` (input, hidden and
                output width of every head).
        """
        super().__init__()
        width = config['embedding_size']
        self.hidden_size = width
        self.embd_size = width

        def make_head():
            return nn.Sequential(
                nn.Linear(self.embd_size, self.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hidden_size, self.embd_size),
            )

        # Creation order is kept fixed so parameter initialisation draws
        # from the RNG in the same sequence.
        self.user_embedding_mean = make_head()
        self.item_embedding_mean = make_head()
        self.user_embedding_logstd = make_head()
        self.item_embedding_logstd = make_head()

    @staticmethod
    def _reparameterize(sample, mean_head, logvar_head, inputs):
        """Return ``(latent, kl)`` for one entity type.

        ``latent`` is ``mean + exp(log_var / 2) * eps`` when ``sample`` is
        truthy, otherwise just ``mean``. ``kl`` is the summed KL divergence
        against a standard normal. The log-variance is capped at 10.
        """
        mean = mean_head(inputs)
        log_var = torch.clamp_max(logvar_head(inputs), max=10)
        kl = -0.5 * torch.sum(
            1.0 + log_var - mean * mean - torch.exp(log_var))
        if not sample:
            return mean, kl
        eps = torch.randn_like(mean, device=inputs.device)
        return mean + torch.exp(log_var / 2) * eps, kl

    def forward(self,
                training,
                user_embeddings,
                item_embeddings,
                ):
        """Resample user and item embeddings.

        Args:
            training: when ``== True``, draw stochastic samples; otherwise
                return the posterior means.
            user_embeddings: user inputs with last dim ``embedding_size``.
            item_embeddings: item inputs with last dim ``embedding_size``.

        Returns:
            ``(user_latent, item_latent, user_kl_loss, item_kl_loss)``.
        """
        sample = training == True  # noqa: E712 -- keep the original equality test
        user_latent, user_kl = self._reparameterize(
            sample, self.user_embedding_mean, self.user_embedding_logstd,
            user_embeddings)
        item_latent, item_kl = self._reparameterize(
            sample, self.item_embedding_mean, self.item_embedding_logstd,
            item_embeddings)
        return user_latent, item_latent, user_kl, item_kl



class CDRVAE(nn.Module):
    """VAE model that generates edges for cross-domain recommendation (CDR).

    As implemented below:
    1. LightGCN propagation over a normalized adjacency matrix.
    2. VAE resampling of the user/item embeddings.
    3. ``VAEPredictor`` scoring.
    """

    def __init__(self, config, total_num_users, total_num_items):
        """
        Args:
            config: mapping providing ``'embedding_size'`` and
                ``'generate_layers'`` (number of LightGCN propagation steps).
            total_num_users: number of user rows, which come first in the
                stacked embedding matrix.
            total_num_items: number of item rows, which follow the users.
        """
        super().__init__()
        self.config = config
        self.total_num_users = total_num_users
        self.total_num_items = total_num_items
        self.hidden_size = config['embedding_size']
        self.embd_size = config['embedding_size']
        # NOTE(review): user_encode / item_encode / loss / sigmoid are not used
        # by this class's methods; kept because they are registered submodules
        # (state_dict keys) and attributes other code may read.
        self.user_encode = nn.Sequential(
            nn.Linear(self.total_num_users, config['embedding_size']),
            nn.ReLU(),
            nn.Linear(config['embedding_size'], config['embedding_size'])
        )
        self.item_encode = nn.Sequential(
            nn.Linear(self.total_num_items, config['embedding_size']),
            nn.ReLU(),
            nn.Linear(config['embedding_size'], config['embedding_size'])
        )
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()
        self.resample_layer = VAEReSample(self.config)
        self.predictor = VAEPredictor(self.config)
        self.n_layers = config['generate_layers']

    def LightGCN(self, all_embeddings, norm_adj_matrix):
        """Propagate embeddings with LightGCN and split into users / items.

        Args:
            all_embeddings: stacked user-then-item embedding matrix.
            norm_adj_matrix: sparse normalized adjacency multiplied on the left
                at each of the ``self.n_layers`` steps.

        Returns:
            ``(user_all_embeddings, item_all_embeddings)``. These are the
            first ``total_num_users`` rows and the remaining
            ``total_num_items`` rows of the mean over the input and every
            propagated layer.
        """
        embeddings_list = [all_embeddings]
        for _ in range(self.n_layers):
            all_embeddings = torch.sparse.mm(norm_adj_matrix, all_embeddings)
            embeddings_list.append(all_embeddings)
        # Average the n_layers + 1 layer representations.
        lightgcn_all_embeddings = torch.mean(
            torch.stack(embeddings_list, dim=1), dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(
            lightgcn_all_embeddings,
            [self.total_num_users, self.total_num_items]
        )
        return user_all_embeddings, item_all_embeddings

    def forward(self,
                all_embeddings,
                label_u_interactions,
                label_i_interactions,
                masked_norm_adj_matrix,
                merge_norm_adj_matrix,
                mode):
        """Compute the training objective or the generation embeddings.

        Args:
            all_embeddings: stacked user-then-item embedding matrix.
            label_u_interactions, label_i_interactions: user / item row
                indices of the masked-out edges (``label_i_interactions`` is
                unused in ``'generate'`` mode).
            masked_norm_adj_matrix: adjacency with the masked edges removed
                (only used in ``'train'`` mode).
            merge_norm_adj_matrix: adjacency used for the resampled latents.
            mode: ``'train'`` or ``'generate'``.

        Returns:
            ``'train'``: scalar = summed predictor scores + user KL + item KL.
            ``'generate'``: ``(u_embeddings, i_embeddings)`` from the predictor.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'generate'``.
        """
        if mode not in ('train', 'generate'):
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'train' or 'generate'.")

        user_all_embeddings, item_all_embeddings = self.LightGCN(
            all_embeddings, merge_norm_adj_matrix)

        # Predictor arguments are positional: the resampled latents fill its
        # ``*_all_embeddings`` slots and the GCN embeddings fill its
        # ``*_latent`` slots. This is done consistently in both modes, so it
        # only fixes the concatenation order inside the predictor.
        if mode == 'train':
            masked_user_all_embeddings, masked_item_all_embeddings = self.LightGCN(
                all_embeddings, masked_norm_adj_matrix)
            user_latent, item_latent, user_kl_loss, item_kl_loss = self.resample_layer(
                True, user_all_embeddings, item_all_embeddings)
            scores = self.predictor(
                mode,
                user_latent,
                item_latent,
                masked_user_all_embeddings,
                masked_item_all_embeddings,
                label_u_interactions,
                label_i_interactions)
            # NOTE(review): ``scores`` are sigmoid probabilities for the
            # masked edges. Minimizing their sum lowers those edges'
            # similarity, which looks inverted relative to the goal of
            # reconstructing masked edges (BCE against target 1 would raise
            # it). Confirm the intended objective before training.
            return scores.sum() + user_kl_loss + item_kl_loss

        # mode == 'generate': use posterior means, no sampling.
        user_latent, item_latent, _, _ = self.resample_layer(
            False, user_all_embeddings, item_all_embeddings)
        # The predictor ignores its item-index argument in generate mode.
        u_embeddings, i_embeddings = self.predictor(
            mode,
            user_latent,
            item_latent,
            user_all_embeddings,
            item_all_embeddings,
            label_u_interactions,
            1)
        return u_embeddings, i_embeddings



